import pymysql
import redis
from pymysql import MySQLError
import time, datetime

"""
1、数据库设计
数据库由权限记录数据库和域归属数据库组成
(1) 权限记录数据表(tb_as)
通过权限记录数据库可以确定一个连接(源访问目的主机的目的端口)是否有权限
rno(int AUTO_INCREMENT)     src_ip(varchar)      dst_ip(varchar)      dst_port(int)        acc_auth(varchar,  yes/no/unknown)
    1001                  10.0.1.2              10.0.3.1              8001                           yes                      
(2) 域归属数据表(tb_as_ascription)
域归属数据表记录各个主机(ip)所属的自治域，通过获取自治域所属，才可以向对应的控制器的交换机下发流表项
ip(varchar PRIMARY KEY)      asc_as(varchar)
     10.0.1.2                   as1

2、读取数据表
从pyTcpdumpServer中获取源ip、目的ip、目的端口等信息，从数据库中查询对应的数据信息。
数据库：redis+MySQL
数据持久化存储在MySQL数据库中，redis作为端到MySQL的缓存加快查询速度。
"""


class FindAuth:
    def __init__(self):
        # 连接MySQL数据库
        try:
            self.conn = pymysql.connect(host='1.1.1.1', port=3306,
                                        user='root', password='111111',
                                        database='auth_data_db', charset='utf8')
        except Exception as error:
            print('连接MySQL出现问题！')
            print('失败原因：', error)
            exit()

        try:
            # 建立redis连接池
            self.conn_pool = redis.ConnectionPool(host='1.1.1.1', port=6379, db=0, decode_responses=True,
                                                  password='111111')
            # 客户端0连接数据库
            self.r0 = redis.StrictRedis(connection_pool=self.conn_pool)
        except Exception as error:
            print('连接redis出现问题！')
            print('失败原因：', error)
            exit()

    """
    由于权限是由源ip、目的ip、目的端口三个唯一确定的，所以这里的redis数据库，采用string类型表，
    即key='src_ip:dst_ip:dst_port', value='acc_auth'

    数据库采用MySQL+redis：redis作为缓存，加快查询速度，所以在查询权限数据时，先查询redis数据库中是
    否存在对应的数据记录，若有，则返回查询数据，若无，则进入MySQL中进行数据查询，并返回结果，同时向redis
    缓存更新记录。
    """

    # 查询连接权限
    def get_data(self, src_ip, dst_ip, dst_port):
        src_ip, dst_ip, dst_port = str(src_ip), str(dst_ip), str(dst_port)
        # redis string表key
        find_info = src_ip + ':' + dst_ip + ':' + dst_port
        # print(find_info)

        # 先查询redis数据库是否存在数据,如果存在数据则返回输出，若不存在则去MySQL中查询，然后再将结果更新到redis中
        result = self.r0.get(find_info)
        # 结果不为空 即redis存在查询的信息，直接输出信息,否则redis中不存在，需要查询MySQL
        if result:
            """
            每次在redis中更新或者写入数据都需要设置过期时间10分钟，然后每查询到一次就重置过期时间10分钟，
            若10分钟没有查询到这个数据，就会被清除。这样设置过期时间主要防止redis缓存数据过多，清除不常用缓存数据"""
            self.r0.expire(find_info, 600)
            # print(result)
            # 返回查询的权限结果
            return result
        else:
            with self.conn.cursor() as cursor:
                try:
                    # 执行MySQL的查询操作
                    cursor.execute('SELECT acc_auth FROM tb_as WHERE '
                                   'src_ip=%s AND dst_ip=%s AND dst_port=%s', (src_ip, dst_ip, dst_port))
                    result_sql = cursor.fetchall()
                    # print(result_sql)
                    if result_sql:
                        # 将查询结果更新写入redis数据库中
                        auth_res = result_sql[0][0]
                        # print(auth_res)
                        self.r0.set(find_info, auth_res)
                        self.r0.expire(find_info, 600)  # 设置过期时间
                        # 返回查询的权限结果
                        return auth_res
                    else:
                        return 'NULL'
                except Exception as error:
                    print(error)
                # finally:
                #     self.conn.close()

    # 获取数据库的所有记录并返回（src_ip, dst_ip, dst_port）
    # 初始化流表项时，通过这些记录下发初始化流表项
    def get_record(self):
        with self.conn.cursor() as cursor:
            try:
                # 执行MySQL的查询操作
                cursor.execute('SELECT src_ip, dst_ip, dst_port FROM tb_as')
                result_sql = cursor.fetchall()
                # print(result_sql)
                if result_sql:
                    # print(result_sql)
                    return result_sql
                else:
                    return 'NULL'
            except Exception as error:
                print(error)
            # finally:
            #     self.conn.close()

    # 查询表tb_as_ascription，获取ip对应的自治域
    def get_as(self, src_ip):
        with self.conn.cursor() as cursor:
            try:
                # 执行MySQL的查询操作
                cursor.execute('SELECT asc_as FROM tb_as_ascription WHERE ip=%s', (src_ip,))
                result_sql = cursor.fetchall()
                # print(result_sql)
                if result_sql:
                    print(result_sql[0][0])
                    return result_sql[0][0]
                else:
                    return 'NULL'
            except Exception as error:
                print(error)
            # finally:
            #     self.conn.close()


if __name__ == '__main__':
    dbs = FindAuth()
    # dbs.get_data('10.0.1.2', '10.0.3.2', '8002')
    # dbs.get_record()
    # dbs.get_as('10.0.2.1')
#
#     # dbs.post_data()
